from typing import Dict
import torch
import numpy as np
import copy
from diffusion_policy.common.pytorch_util import dict_apply
from diffusion_policy.common.streaming_replay_buffer import StreamingReplayBuffer
from diffusion_policy.common.sampler import (
    SequenceSampler, get_val_mask, downsample_mask)
from diffusion_policy.model.common.normalizer import LinearNormalizer
from diffusion_policy.dataset.base_dataset import BaseImageDataset
from diffusion_policy.common.normalize_util import get_image_range_normalizer
import torch.nn.functional as F

class DualArmImageDataset(BaseImageDataset):
    """Dual-arm image dataset read from a zarr replay buffer.

    Each item is a dict with:
      * ``obs``: the first ``n_obs_steps`` frames of every key in
        ``_RGB_KEYS`` (channels-first, resized to ``image_size``, divided by
        255) and of every key in ``_LOWDIM_KEYS`` (cast to float32);
      * ``action``: the whole ``horizon``-long action window (float32).

    Episodes are split into train / validation with ``val_ratio``; the
    training split can additionally be capped with ``max_train_episodes``.
    Use ``get_validation_dataset()`` to sample the held-out episodes.
    """

    # Camera streams loaded from the replay buffer. The arm-mounted cameras
    # ('arm1_camera_rgb', 'arm2_camera_rgb') are currently not used.
    _RGB_KEYS = ('base_camera_rgb',)
    # Low-dimensional observation streams; these are also fitted by the
    # normalizer in get_normalizer().
    _LOWDIM_KEYS = (
        'joint_qpos',
        'joint_qvel',
        'gripper_0_qpos',
        'gripper_0_qvel',
        'eef_0_pos',
        'eef_0_quat',
        'gripper_1_qpos',
        'gripper_1_qvel',
        'eef_1_pos',
        'eef_1_quat',
        'obj_state',
    )

    def __init__(self,
            zarr_path,
            horizon=1,
            n_obs_steps=1,
            pad_before=0,
            pad_after=0,
            seed=42,
            val_ratio=0.0,
            max_train_episodes=None,
            image_size=(224, 224),
            ):
        """Load the replay buffer and build the training sampler.

        Args:
            zarr_path: path passed to ``StreamingReplayBuffer.copy_from_path``.
            horizon: length of each sampled sequence (and of ``action``).
            n_obs_steps: number of leading frames of each sequence kept as
                observations.
            pad_before: padding forwarded to ``SequenceSampler``.
            pad_after: padding forwarded to ``SequenceSampler``.
            seed: seed for the validation split and for the training-episode
                downsampling.
            val_ratio: fraction of episodes held out for validation.
            max_train_episodes: optional cap on the number of training episodes.
            image_size: output size for ``F.interpolate`` — an ``(H, W)``
                sequence or a single int.
        """
        super().__init__()
        # F.interpolate expects an int or a tuple for ``size``; normalise any
        # other sequence type (e.g. a list coming from a config file).
        if not isinstance(image_size, int):
            image_size = tuple(image_size)
        self.image_size = image_size

        self.replay_buffer = StreamingReplayBuffer.copy_from_path(
            zarr_path,
            keys=[*self._RGB_KEYS, *self._LOWDIM_KEYS, 'action'])

        val_mask = get_val_mask(
            n_episodes=self.replay_buffer.n_episodes,
            val_ratio=val_ratio,
            seed=seed)
        train_mask = downsample_mask(
            mask=~val_mask,
            max_n=max_train_episodes,
            seed=seed)

        # Attributes used by _make_sampler must be set before it is called.
        self.train_mask = train_mask
        self.val_mask = val_mask
        self.horizon = horizon
        self.pad_before = pad_before
        self.pad_after = pad_after
        self.n_obs_steps = n_obs_steps
        self.sampler = self._make_sampler(train_mask)

    def _make_sampler(self, episode_mask):
        """Build a SequenceSampler over the episodes selected by ``episode_mask``."""
        return SequenceSampler(
            replay_buffer=self.replay_buffer,
            sequence_length=self.horizon,
            pad_before=self.pad_before,
            pad_after=self.pad_after,
            episode_mask=episode_mask)

    def get_validation_dataset(self):
        """Return a copy of this dataset that samples only the validation episodes.

        The result is a shallow copy sharing ``replay_buffer`` with ``self``;
        only its sampler and ``train_mask`` (the mask the copy samples from)
        are replaced. This override is what makes the episodes held out by
        ``val_ratio`` reachable; nothing else in this class samples them.
        """
        val_set = copy.copy(self)
        val_set.sampler = self._make_sampler(self.val_mask)
        val_set.train_mask = self.val_mask
        return val_set

    def _resize_rgb(self, images):
        """Return channels 0-2 of ``images`` as a resized float tensor.

        ``images`` is indexed as [T, H, W, C] (see the permute below). The
        result is [T, 3, *image_size], divided by 255 and bilinearly resized.
        NOTE(review): the /255 assumes 8-bit pixel values — confirm.
        """
        rgb = torch.from_numpy(images[..., :3]).float()  # [T, H, W, 3]
        rgb = rgb.permute(0, 3, 1, 2)  # [T, 3, H, W]
        return F.interpolate(
            rgb / 255.0,
            size=self.image_size,
            mode='bilinear',
            align_corners=False
        )

    def _process_mask_image_batch(self, images):
        """Process a batch of RGB + mask frames.

        Channels 0-2 are treated as RGB (scaled by 1/255, bilinear resize).
        Channels from 3 onward are treated as a mask: nearest-neighbour
        resize, then binarised with a 0.5 threshold.

        Returns:
            NumPy float32 array of shape [T, 3 + n_mask_channels, *image_size].
        """
        rgb = self._resize_rgb(images)

        mask = torch.from_numpy(images[..., 3:]).float()  # [T, H, W, C-3]
        mask = mask.permute(0, 3, 1, 2)  # [T, C-3, H, W]
        # Nearest-neighbour keeps mask values from being blended across edges.
        mask = F.interpolate(
            mask,
            size=self.image_size,
            mode='nearest'
        )
        mask = (mask > 0.5).float()

        return torch.cat([rgb, mask], dim=1).numpy()

    def _process_image_batch(self, images):
        """Process a batch of RGB frames ([T, H, W, C]; only channels 0-2 are used).

        Returns:
            NumPy float32 array of shape [T, 3, *image_size] scaled by 1/255.
        """
        return self._resize_rgb(images).numpy()

    def _sample_to_data(self, sample):
        """Convert one raw sampled sequence into the ``{'obs', 'action'}`` dict.

        Only the first ``n_obs_steps`` frames are kept for observations; the
        full action window is returned.
        """
        T_slice = slice(self.n_obs_steps)

        obs = {
            key: self._process_image_batch(sample[key][T_slice])
            for key in self._RGB_KEYS
        }
        obs.update({
            key: sample[key][T_slice].astype(np.float32)
            for key in self._LOWDIM_KEYS
        })
        return {
            'obs': obs,
            'action': sample['action'].astype(np.float32),
        }

    def get_normalizer(self, mode='limits', **kwargs):
        """Build a LinearNormalizer for actions, low-dim observations and images.

        Action and low-dim keys are fitted on the whole replay buffer
        (``mode`` and ``kwargs`` are forwarded to ``LinearNormalizer.fit``);
        image keys get ``get_image_range_normalizer()``.
        """
        data = {'action': self.replay_buffer['action']}
        for key in self._LOWDIM_KEYS:
            data[key] = self.replay_buffer[key]
        normalizer = LinearNormalizer()
        normalizer.fit(data=data, last_n_dims=1, mode=mode, **kwargs)
        for key in self._RGB_KEYS:
            normalizer[key] = get_image_range_normalizer()
        return normalizer

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        """Sample sequence ``idx`` and return it as nested torch tensors."""
        sample = self.sampler.sample_sequence(idx)
        data = self._sample_to_data(sample)
        return dict_apply(data, torch.from_numpy)

    def get_all_actions(self) -> torch.Tensor:
        """Return every action stored in the replay buffer as one tensor."""
        # ``[:]`` / np.asarray materialise the data in case the streaming
        # buffer hands back a lazy array-like instead of a NumPy array
        # (torch.from_numpy only accepts ndarrays); no-op for ndarrays.
        return torch.from_numpy(np.asarray(self.replay_buffer['action'][:]))

    def __len__(self):
        """Number of sequences available to the current sampler."""
        return len(self.sampler)

